# -*- coding: UTF-8 -*-
import logging
import pymemcache

from typing import List, Tuple

from django.core.checks.security.base import check_secret_key

from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from sql.models import SqlWorkflow

logger = logging.getLogger("default")


class MemcachedEngine(EngineBase):
    """Engine that sends simple text commands to one or more Memcached nodes.

    Each comma-separated entry of ``instance.host`` becomes a node named
    ``"Node - <index>"``. The nodes are exposed as "databases", so the caller
    chooses the target node through ``db_name``. All nodes share ``self.port``.
    Read commands go through ``query``/``query_check``. Write commands go
    through ``execute``/``execute_check``.
    """

    test_query = "stats"
    name = "Memcached"
    info = "Memcached engine"

    def __init__(self, instance=None):
        super().__init__(instance=instance)
        # Node name -> host, e.g. {"Node - 0": "10.0.0.1"}
        self.nodes = {}

        if not instance:
            return

        # instance.host may list several hosts separated by ",". Blank entries
        # (e.g. from a trailing comma) are skipped. The index still comes from
        # the original position, so the names of the remaining nodes are stable.
        for i, host in enumerate(instance.host.split(",")):
            host = host.strip()
            if host:
                self.nodes[f"Node - {i}"] = host

    @staticmethod
    def _close_quietly(conn):
        """Close a pymemcache client if one was created; log instead of raising."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            logger.warning(f"关闭Memcached连接失败: {str(e)}")

    @staticmethod
    def _to_text(value):
        """Decode bytes returned by pymemcache (e.g. from version()) for display."""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    def get_connection(self, db_name=None):
        """Create a new pymemcache client for node ``db_name``.

        ``db_name`` defaults to ``"Node - 0"``. A fresh client is built on every
        call and is not cached, so callers are responsible for closing it.

        Raises:
            Exception: if the node name is unknown or client creation fails.
        """
        db_name = db_name or "Node - 0"

        if db_name not in self.nodes:
            # The original message claimed a fallback to a default node, but
            # the code raises; the log now says only what actually happens.
            logger.warning(f"Memcached节点 {db_name} 不存在")
            raise Exception(f"Memcached节点 {db_name} 不存在")

        node_host = self.nodes[db_name]

        try:
            return pymemcache.Client(
                server=(node_host, self.port), connect_timeout=10.0, timeout=10.0
            )
        except Exception as e:
            raise Exception(f"连接Memcached节点 {node_host} 失败: {str(e)}") from e

    def test_connection(self):
        """Check that the default node answers ``version``.

        Returns:
            ResultSet with one status row containing the decoded server version.

        Raises:
            Exception: if connecting or querying the version fails.
        """
        conn = None
        try:
            conn = self.get_connection(None)
            version = self._to_text(conn.version())
            return ResultSet(rows=[[f"连接成功，版本: {version}"]], column_list=["状态"])
        except Exception as e:
            logger.error(f"测试连接失败: {str(e)}")
            raise Exception(f"测试连接失败: {str(e)}") from e
        finally:
            self._close_quietly(conn)

    def get_all_databases(self):
        """Return every configured node as a "database" row."""
        return ResultSet(column_list=["节点"], rows=[[name] for name in self.nodes])

    def get_all_tables(self, db_name, **kwargs):
        """Memcached has no tables; always return an empty ResultSet."""
        return ResultSet(rows=[])

    def query(
        self,
        db_name=None,
        sql="",
        limit_num=0,
        close_conn=True,
        parameters=None,
        **kwargs,
    ):
        """Run one read command on node ``db_name`` and return its ResultSet.

        The command is dispatched through the module-level command table.
        ``limit_num`` and ``parameters`` are accepted for interface
        compatibility but are not used. Errors are not raised; they are
        reported in ``ResultSet.error`` and as a single row.
        """
        result_set = ResultSet(full_sql=sql)
        conn = None

        try:
            conn = self.get_connection(db_name)
            result_set = _handle_cmd(conn, sql)
        except Exception as e:
            logger.error(f"查询执行失败: {str(e)}")
            result_set.error = str(e)
            result_set.rows = [[f"错误: {str(e)}"]]
        finally:
            # The client is created for this call only and never stored, so it
            # is always closed here regardless of close_conn.
            self._close_quietly(conn)
            if close_conn and self.conn:
                self.conn = None

        return result_set

    def query_check(self, db_name=None, sql=""):
        """Allow only read commands (version, get, gets) in the query path.

        Returns:
            dict with ``bad_query`` (bool) and ``filtered_sql`` (the input,
            unchanged). A rejected command also carries ``msg``.
        """
        cmd, _ = _parse_cmd_args(sql)
        allowed_commands = ("version", "get", "gets")

        if cmd not in allowed_commands:
            return {
                "bad_query": True,
                "filtered_sql": sql,
                "msg": "仅支持 (version, get, gets) 命令",
            }

        return {"bad_query": False, "filtered_sql": sql}

    def execute(self, db_name=None, sql="", **kwargs):
        """Execute one write command and wrap the outcome in a ReviewSet.

        A handler result of ``"FAIL"`` and any raised exception both produce a
        failed ReviewResult (errlevel 2). The error text is also copied into
        ``ReviewSet.error`` so the failure is visible to the caller.
        """
        execute_result = ReviewSet(full_sql=sql)
        conn = None

        try:
            conn = self.get_connection(db_name)
            cmd_result = _handle_cmd(conn, sql)

            # Write commands are expected to yield exactly one status cell.
            # Explicit checks replace assert, which is stripped under -O.
            if len(cmd_result.rows) != 1:
                raise Exception("命令执行结果行数不是1")
            if len(cmd_result.rows[0]) != 1:
                raise Exception("命令执行结果列数不是1")

            if cmd_result.rows[0][0] == "FAIL":
                execute_result.rows.append(
                    ReviewResult(
                        id=1,
                        affected_rows=0,
                        sql=sql,
                        stage="Execute",
                        stagestatus="Fail",
                        errlevel=2,
                        errormessage="命令执行失败",
                    )
                )
                execute_result.affected_rows = 0
                execute_result.error = "命令执行失败"
            else:
                execute_result.rows.append(
                    ReviewResult(
                        id=1,
                        affected_rows=1,
                        sql=sql,
                        stage="Execute",
                        stagestatus="Success",
                    )
                )
                execute_result.affected_rows = cmd_result.affected_rows
                execute_result.error = cmd_result.error
        except Exception as e:
            logger.error(f"执行语句失败: {str(e)}")
            execute_result.error = str(e)
            # Rows must be ReviewResult objects like the other paths, not dicts.
            execute_result.rows = [
                ReviewResult(
                    id=1,
                    affected_rows=0,
                    sql=sql,
                    stage="Execute",
                    stagestatus="Fail",
                    errlevel=2,
                    errormessage=str(e),
                )
            ]
        finally:
            self._close_quietly(conn)

        return execute_result

    def execute_check(self, db_name=None, sql=""):
        """Allow only write commands (set, delete, incr, decr, touch) for execution.

        Returns:
            ReviewSet with one ReviewResult row. Unsupported commands get
            errlevel 2 and increment ``error_count``. Supported commands set
            ``checked = True``.
        """
        check_result = ReviewSet(full_sql=sql)

        allowed_commands = ("set", "delete", "incr", "decr", "touch")
        cmd, _ = _parse_cmd_args(sql)

        if cmd not in allowed_commands:
            check_result.error_count += 1
            check_result.error = f"不支持的命令: {cmd}"
            check_result.rows = [
                ReviewResult(
                    id=1,
                    affected_rows=0,
                    sql=sql,
                    stage="Check",
                    stagestatus="Fail",
                    errlevel=2,
                    errormessage=f"不支持的命令: {cmd}",
                )
            ]
        else:
            check_result.rows = [
                ReviewResult(
                    id=1,
                    affected_rows=1,
                    sql=sql,
                    stage="Check",
                    stagestatus="Success",
                )
            ]
            check_result.checked = True

        return check_result

    def execute_workflow(self, workflow: SqlWorkflow):
        """Execute a workflow's SQL content on the workflow's node; returns a ReviewSet."""
        return self.execute(
            db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
        )

    def get_execute_percentage(self):
        """Execution is a single command, so progress is always reported as 100."""
        return 100

    @property
    def server_version(self):
        """Return the default node's version as a tuple of up to three ints.

        pymemcache returns the version as bytes. It is decoded before parsing,
        so ``b"1.6.9"`` yields ``(1, 6, 9)`` rather than the bytes repr.
        Non-numeric parts become 0. Any failure yields an empty tuple.
        """
        conn = None
        try:
            conn = self.get_connection()
            version = self._to_text(conn.version())
            parts = str(version).split(".")
            return tuple(int(part) if part.isdigit() else 0 for part in parts[:3])
        except Exception as e:
            logger.error(f"获取Memcached版本失败: {str(e)}")
            return tuple()
        finally:
            self._close_quietly(conn)


# 命令处理函数


def _handle_get(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``get <key>`` and return the stored value as a one-column ResultSet.

    Only the first argument is used as the key. When the client returns None,
    the string ``"None"`` is shown instead.

    Raises:
        Exception: if no key was supplied or the client call fails.
    """
    output = ResultSet(full_sql=sql)

    if not cmd_args:
        raise Exception("get命令格式错误")

    try:
        lookup_key = cmd_args[0].strip()
        fetched = conn.get(lookup_key)
        output.column_list = ["值"]
        output.rows = [["None" if fetched is None else fetched]]
    except Exception as err:
        raise Exception(f"get命令执行失败: {str(err)}")

    output.affected_rows = len(output.rows)
    return output


def _handle_set(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``set <key> <value> [expiry]`` and report ``"OK"`` or ``"FAIL"``.

    ``expiry`` is in seconds and defaults to 0. Arguments after the third are
    ignored.

    Raises:
        Exception: if fewer than two arguments are given, ``expiry`` is not an
            integer, or the client/server reports an error.
    """
    result_set = ResultSet(full_sql=sql)

    if len(cmd_args) < 2:
        raise Exception("set命令格式错误")

    try:
        key = cmd_args[0].strip()
        value = cmd_args[1].strip()
        expiry = int(cmd_args[2].strip()) if len(cmd_args) > 2 else 0
        # pymemcache clients default to noreply=True, in which case set()
        # returns True without waiting for the server. Ask for the reply so
        # server-side errors are reported instead of a blind "OK".
        ok = conn.set(key, value, expire=expiry, noreply=False)
        result_set.rows = [["OK"] if ok else ["FAIL"]]
        result_set.column_list = ["状态"]
    except Exception as e:
        raise Exception(f"set命令执行失败: {str(e)}") from e

    result_set.affected_rows = len(result_set.rows)
    return result_set


def _handle_delete(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``delete <key>`` and report ``"OK"`` or ``"FAIL"``.

    With a reply requested, pymemcache returns False when the key does not
    exist, which is reported as ``"FAIL"``.

    Raises:
        Exception: if no key was supplied or the client call fails.
    """
    result_set = ResultSet(full_sql=sql)

    if len(cmd_args) < 1:
        raise Exception("delete命令格式错误")

    try:
        key = cmd_args[0].strip()
        # Without noreply=False the client's default (noreply=True) makes
        # delete() always return True, hiding "key not found".
        ok = conn.delete(key, noreply=False)
        result_set.rows = [["OK"] if ok else ["FAIL"]]
        result_set.column_list = ["状态"]
    except Exception as e:
        raise Exception(f"delete命令执行失败: {str(e)}") from e

    result_set.affected_rows = len(result_set.rows)
    return result_set


def _handle_version(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``version`` and return the server version as a one-cell ResultSet.

    pymemcache returns the version as bytes. It is decoded so the cell shows
    ``1.6.21`` rather than ``b'1.6.21'``. Arguments are ignored.
    """
    result_set = ResultSet(full_sql=sql)
    version = conn.version()
    if isinstance(version, bytes):
        version = version.decode("utf-8", errors="replace")
    result_set.rows = [[version]]
    result_set.column_list = ["版本"]
    result_set.affected_rows = 1
    return result_set


def _handle_gets(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``gets <key1> [<key2> ...]`` and return one row per requested key.

    Columns are key, value and CAS token. ``gets_many`` omits keys that are
    not stored, so those keys are listed explicitly with ``"None"`` for value
    and CAS. Rows follow the order the keys were given. Duplicates are
    reported once and blank arguments are ignored.

    Raises:
        Exception: if no key was supplied or the client call fails.
    """
    result_set = ResultSet(full_sql=sql)

    # dict.fromkeys de-duplicates while keeping the user's order.
    keys = list(dict.fromkeys(v.strip() for v in cmd_args if v.strip()))
    if not keys:
        raise Exception("gets命令格式错误")

    try:
        values = conn.gets_many(keys)
        result_set.column_list = ["键", "值", "CAS"]
        for key in keys:
            value, cas = values.get(key, (None, None))
            result_set.rows.append(
                [
                    key,
                    value if value is not None else "None",
                    cas if cas is not None else "None",
                ]
            )
    except Exception as e:
        raise Exception(f"gets命令执行失败: {str(e)}") from e

    result_set.affected_rows = len(result_set.rows)
    return result_set


def _handle_incr(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``incr <key> [delta]``, where ``delta`` defaults to 1.

    The single result cell holds the new counter value as a string, or
    ``"FAIL"`` when the client returns None.

    Raises:
        Exception: if no key was supplied, ``delta`` is not an integer, or the
            client call fails.
    """
    out = ResultSet(full_sql=sql)

    if not cmd_args:
        raise Exception("incr命令格式错误")
    try:
        target = cmd_args[0].strip()
        step = 1
        if len(cmd_args) > 1:
            step = int(cmd_args[1].strip())
        counter = conn.incr(target, step)
        cell = "FAIL" if counter is None else str(counter)
        out.rows = [[cell]]
        out.column_list = ["结果"]
    except Exception as err:
        raise Exception(f"incr命令执行失败: {str(err)}")

    out.affected_rows = 1
    return out


def _handle_decr(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``decr <key> [delta]``, where ``delta`` defaults to 1.

    The single result cell holds the new counter value as a string, or
    ``"FAIL"`` when the client returns None.

    Raises:
        Exception: if no key was supplied, ``delta`` is not an integer, or the
            client call fails.
    """
    out = ResultSet(full_sql=sql)

    if not cmd_args:
        raise Exception("decr命令格式错误")
    try:
        target = cmd_args[0].strip()
        step = 1
        if len(cmd_args) > 1:
            step = int(cmd_args[1].strip())
        counter = conn.decr(target, step)
        cell = "FAIL" if counter is None else str(counter)
        out.rows = [[cell]]
        out.column_list = ["结果"]
    except Exception as err:
        raise Exception(f"decr命令执行失败: {str(err)}")

    out.affected_rows = 1
    return out


def _handle_touch(conn: pymemcache.Client, sql: str, cmd_args: List[str]):
    """Run ``touch <key> <expiry>`` to reset a key's expiration.

    Reports ``"OK"`` when the expiration was updated. With a reply requested,
    pymemcache returns False for a missing key, which is reported as ``"FAIL"``.

    Raises:
        Exception: if fewer than two arguments are given, ``expiry`` is not an
            integer, or the client call fails.
    """
    result_set = ResultSet(full_sql=sql)

    if len(cmd_args) < 2:
        raise Exception("touch命令格式错误")

    try:
        key = cmd_args[0].strip()
        expiry = int(cmd_args[1].strip())
        # The client's default noreply=True would make touch() always return
        # True; request the reply so a missing key shows up as "FAIL".
        ok = conn.touch(key, expire=expiry, noreply=False)
        result_set.rows = [["OK"] if ok else ["FAIL"]]
        result_set.column_list = ["状态"]
    except Exception as e:
        raise Exception(f"touch命令执行失败: {str(e)}") from e

    result_set.affected_rows = 1
    return result_set


# Dispatch table: lower-case command name -> handler(conn, sql, cmd_args).
cmd_handlers = dict(
    get=_handle_get,
    set=_handle_set,
    delete=_handle_delete,
    version=_handle_version,
    gets=_handle_gets,
    incr=_handle_incr,
    decr=_handle_decr,
    touch=_handle_touch,
)


def _parse_cmd_args(sql: str) -> Tuple[str, List[str]]:
    """
    解析命令参数
    """
    cmd = sql.split(" ")[0].strip().lower()
    cmd_args = sql.split(" ")[1:]
    return cmd, cmd_args


def _handle_cmd(conn: pymemcache.Client, sql: str):
    """Parse ``sql`` as one Memcached command and dispatch it to its handler.

    Only the command name is matched case-insensitively. Keys and values keep
    their original case, because Memcached keys are case-sensitive. The
    previous implementation lowercased the whole line, which broke mixed-case
    keys and altered stored values.

    Raises:
        Exception: for an empty command or a command not in ``cmd_handlers``.
    """
    sql = (sql or "").strip()
    if not sql:
        raise Exception("空SQL语句")

    cmd, cmd_args = _parse_cmd_args(sql)

    handler = cmd_handlers.get(cmd)
    if handler is None:
        raise Exception(f"不支持的命令: {cmd}")

    return handler(conn, sql, cmd_args)
